#!/usr/bin/env python
# -*- coding: UTF-8 -*-

"""
Provide coverage QC for assembled sequences:
1. plot paired-end reads as curves
2. plot base coverage and mate coverage
3. plot gaps in the sequence (if any)
"""
from collections import defaultdict

from ..apps.base import ActionDispatcher, logger, need_update, sh
from ..formats.base import BaseFile, must_open
from ..formats.bed import BedLine, sort
from ..formats.sizes import Sizes


class Coverage(BaseFile):
    """
    Three-column .coverage file, often generated by `genomeCoverageBed -d`
    contigID baseID coverage
    """

    def __init__(self, bedfile, sizesfile):

        bedfile = sort([bedfile])
        coveragefile = bedfile + ".coverage"
        if need_update(bedfile, coveragefile):
            cmd = "genomeCoverageBed"
            cmd += " -bg -i {0} -g {1}".format(bedfile, sizesfile)
            sh(cmd, outfile=coveragefile)

        self.sizes = Sizes(sizesfile).mapping

        filename = coveragefile
        assert filename.endswith(".coverage")
        super().__init__(filename)

    def get_plot_data(self, ctg, bins=None):
        import numpy as np

        from jcvi.algorithms.matrix import chunk_average

        fp = open(self.filename)
        size = self.sizes[ctg]

        data = np.zeros((size,), dtype=np.int)
        for row in fp:
            seqid, start, end, cov = row.split()
            if seqid != ctg:
                continue

            start, end = int(start), int(end)
            cov = int(cov)
            data[start:end] = cov

        bases = np.arange(1, size + 1)
        if bins:
            window = size / bins
            bases = bases[::window]
            data = chunk_average(data, window)

        return bases, data


def main():

    actions = (("posmap", "QC based on indexed posmap file"),)
    p = ActionDispatcher(actions)
    p.dispatch(globals())


def clone_name(s, ca=False):
    """
    >>> clone_name("120038881639")
    "0038881639"
    >>> clone_name("GW11W6RK01DAJDWa")
    "GW11W6RK01DAJDW"
    """
    if not ca:
        return s[:-1]

    if s[0] == "1":
        return s[2:]
    return s.rstrip("ab")


def bed_to_bedpe(
    bedfile, bedpefile, pairsbedfile=None, matesfile=None, ca=False, strand=False
):
    """
    This converts the bedfile to bedpefile, assuming the reads are from CA.
    """
    fp = must_open(bedfile)
    fw = must_open(bedpefile, "w")
    if pairsbedfile:
        fwpairs = must_open(pairsbedfile, "w")

    clones = defaultdict(list)
    for row in fp:
        b = BedLine(row)
        name = b.accn
        clonename = clone_name(name, ca=ca)
        clones[clonename].append(b)

    if matesfile:
        fp = open(matesfile)
        libraryline = next(fp)
        # 'library bes     37896   126916'
        lib, name, smin, smax = libraryline.split()
        assert lib == "library"
        smin, smax = int(smin), int(smax)
        logger.debug(
            "Happy mates for lib {0} fall between {1} - {2}".format(name, smin, smax)
        )

    nbedpe = 0
    nspan = 0
    for clonename, blines in clones.items():
        nlines = len(blines)
        if nlines == 2:
            a, b = blines
            aseqid, astart, aend = a.seqid, a.start, a.end
            bseqid, bstart, bend = b.seqid, b.start, b.end
            outcols = [aseqid, astart - 1, aend, bseqid, bstart - 1, bend, clonename]
            if strand:
                outcols.extend([0, a.strand, b.strand])
            print("\t".join(str(x) for x in outcols), file=fw)
            nbedpe += 1
        elif nlines == 1:
            (a,) = blines
            aseqid, astart, aend = a.seqid, a.start, a.end
            bseqid, bstart, bend = 0, 0, 0
        else:  # More than two lines per pair
            pass

        if pairsbedfile:
            start = min(astart, bstart) if bstart > 0 else astart
            end = max(aend, bend) if bend > 0 else aend
            if aseqid != bseqid:
                continue

            span = end - start + 1
            if (not matesfile) or (smin <= span <= smax):
                print(
                    "\t".join(str(x) for x in (aseqid, start - 1, end, clonename)),
                    file=fwpairs,
                )
                nspan += 1

    fw.close()
    logger.debug("A total of {0} bedpe written to `{1}`.".format(nbedpe, bedpefile))
    if pairsbedfile:
        fwpairs.close()
        logger.debug(
            "A total of {0} spans written to `{1}`.".format(nspan, pairsbedfile)
        )


if __name__ == "__main__":
    main()
